{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import re\n",
    "import jieba\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from gensim.models import KeyedVectors\n",
    "from sklearn.model_selection import train_test_split\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from sklearn.model_selection import train_test_split\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, GRU, Embedding, LSTM, Bidirectional\n",
    "from keras.optimizers import Adam\n",
    "from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "300\n",
      "Vocab(count:255362, index:4574)\n"
     ]
    }
   ],
   "source": [
    "# Load the pre-trained Chinese word vectors (word2vec text format) with gensim.\n",
    "cn_model = KeyedVectors.load_word2vec_format('sgns.zhihu.bigram', binary=False)\n",
    "# Embedding dimensionality, read from the vector of a word known to be in the vocabulary.\n",
    "probe_word = '清华'\n",
    "embedding_dim = len(cn_model[probe_word])\n",
    "print(embedding_dim)\n",
    "# The Vocab entry printed below carries the word's count field and its row index in the model.\n",
    "voc = cn_model.vocab[probe_word]\n",
    "print(voc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2000\n",
      "2000\n"
     ]
    }
   ],
   "source": [
    "# File names of the pre-labelled reviews: one file per review, in the 'pos' and 'neg' folders.\n",
    "# os.listdir returns names in arbitrary order, so sort them to keep the sample order\n",
    "# (and therefore the seeded train/test split) reproducible across machines.\n",
    "pos_txts = sorted(os.listdir('pos'))\n",
    "neg_txts = sorted(os.listdir('neg'))\n",
    "# Number of positive and negative reviews.\n",
    "print(len(pos_txts))\n",
    "print(len(neg_txts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every review into one list (one element per review): positives first, then negatives.\n",
    "# The label vector built later in the notebook relies on this pos-then-neg ordering.\n",
    "def read_reviews(folder, file_names):\n",
    "    \"\"\"Return the stripped text of each file in `folder`, in `file_names` order.\"\"\"\n",
    "    reviews = []\n",
    "    for file_name in file_names:\n",
    "        # No explicit encoding is given: the platform default is used and undecodable\n",
    "        # bytes are dropped (errors='ignore'). The `with` block closes the file itself.\n",
    "        with open(os.path.join(folder, file_name), 'r', errors='ignore') as f:\n",
    "            reviews.append(f.read().strip())\n",
    "    return reviews\n",
    "\n",
    "train_texts_orig = read_reviews('pos', pos_txts) + read_reviews('neg', neg_txts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total number of reviews loaded (4000 in the recorded run)\n",
    "len(train_texts_orig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache C:\\Users\\46124\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 0.880 seconds.\n",
      "Prefix dict has been built succesfully.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4000\n"
     ]
    }
   ],
   "source": [
    "# Strip punctuation from every review, segment it with jieba, then map each word to its\n",
    "# row index in the pre-trained vocabulary. Words missing from the vocabulary become index 0.\n",
    "# The result is one list of word indices per review.\n",
    "# NOTE(review): index 0 is also the row of a real vocabulary word, so unknown words share\n",
    "# that word's index downstream - confirm this is acceptable.\n",
    "#\n",
    "# The pattern is compiled once instead of on every iteration, and written as a raw string so\n",
    "# the regex escapes are not parsed as (invalid) Python string escapes. The hyphen is escaped:\n",
    "# the unescaped \",-|\" was a character range (',' through '|') that also deleted every ASCII\n",
    "# digit and letter. The ASCII punctuation that range used to cover is now listed explicitly.\n",
    "punctuation_re = re.compile(\n",
    "    r\"[\\s+\\.\\!\\/_,\\-|$%^*(+\\\"\\'):;<=>?@\\[\\\\\\]`{}]+|[+——！，； 。？ 、~@#￥%……&*（）]+\")\n",
    "\n",
    "def text_to_indices(text):\n",
    "    \"\"\"Return the vocabulary indices of one review's words (0 for an unknown word).\"\"\"\n",
    "    indices = []\n",
    "    for word in jieba.cut(punctuation_re.sub(\"\", text)):\n",
    "        try:\n",
    "            # Word found: use its row index in the word-vector model.\n",
    "            indices.append(cn_model.vocab[word].index)\n",
    "        except KeyError:\n",
    "            # Word not in the word-vector vocabulary.\n",
    "            indices.append(0)\n",
    "    return indices\n",
    "\n",
    "train_tokens = [text_to_indices(text) for text in train_texts_orig]\n",
    "print(len(train_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4000\n"
     ]
    }
   ],
   "source": [
    "# Reviews differ in length. Padding everything to the longest review would waste a lot of\n",
    "# computation, so a compromise length is chosen in the following cells.\n",
    "# Length of each review = number of tokens after segmentation, stored as an ndarray.\n",
    "num_tokens = np.array([len(tokens) for tokens in train_tokens])\n",
    "print(len(num_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "68.77625"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average review length in tokens\n",
    "num_tokens.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1438"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Length in tokens of the longest review\n",
    "num_tokens.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEWCAYAAACXGLsWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHWpJREFUeJzt3XmcHVWd9/HPl7AGwpqAJCE0aNgHECMyggrC47ApPCPrCAaIk2FEQMBHEkFZHn2EwUHAPayRHREkgjIgkgFfrAlbAojmFSCEAAkgEJYBgr/njzpNbi63u6qXe6u67/f9et1X3zpVdc6vq7vvr+ucqlOKCMzMzLqzXNkBmJlZ9TlZmJlZLicLMzPL5WRhZma5nCzMzCyXk4WZmeVysrBCJP1c0rf7qa4xkl6XNCQtT5f0lf6oO9X3e0nj+6u+HrT7XUkvSnq+H+raWdL8/oirDzGEpI+U0G7p37t9kJOFIekpSW9JWizpFUl3STpS0vu/HxFxZET834J17dbdNhExLyJWi4j3+iH2UyVdVlf/HhExta919zCODYATgC0i4kMN1vsDsAtlJSXrGScL6/T5iBgGbAicAZwIXNjfjUhavr/rrIgNgZciYmHZgZg1g5OFLSMiXo2IacCBwHhJWwFIukTSd9P74ZJuTGchL0u6U9Jyki4FxgC/Td1M35TUkf5znCBpHvDHmrLaxPFhSfdJelXSDZLWTm194D/yzrMXSbsD3wIOTO09nNa/362V4jpZ0tOSFkr6paQ10rrOOMZLmpe6kE7q6thIWiPtvyjVd3KqfzfgVmBkiuOSuv1WBX5fs/51SSMlrSTpHEkL0uscSSt10fYxkh6TNDot7y3poZozwa3rjs83JD2SjufVklbu7mfXza9EZ50rSfpBOk4vpG7JVWp/RpJOSMf4OUmH1+y7jqTfSnpN0v2pu+5Pad0dabOH03E5sGa/hvVZOZwsrKGIuA+YD3yqweoT0roRwHpkH9gREYcC88jOUlaLiP+o2eczwObAP3XR5JeBI4CRwBLgvAIx3gz8P+Dq1N42DTY7LL12ATYGVgN+XLfNTsCmwK7AdyRt3kWTPwLWSPV8JsV8eET8AdgDWJDiOKwuzjfq1q8WEQuAk4AdgG2BbYDtgZPrG1U2VnQY8JmImC9pO+Ai4N+AdYBfANPqEs0BwO7ARsDWaX/o4mfXxfdb60xgkxTrR4BRwHdq1n8oHZtRwATgJ5LWSut+AryRthmfXp3H5tPp7TbpuFxdoD4rgZOFdWcBsHaD8neB9YENI+LdiLgz8icZOzUi3oiIt7pYf2lEzE4frN8GDlAaAO+jLwFnR8TciHgdmAwcVHdWc1pEvBURDwMPk31wLyPFciAwOSIWR8RTwH8Ch/YxttMjYmFELAJOq6tPks4mS7C7pG0A/hX4RUTcGxHvpfGZt8kST6fzImJBRLwM/JbsQx568bOTpNTmcRHxckQsJkvSB9Vs9m76Xt6NiN8BrwObpuP2ReCUiHgzIh4DiownNayvwH7WJE4W1p1RwMsNys8C5gC3SJoraVKBup7pwfqngRWA4YWi7N7IVF9t3cuT/VfdqfbqpTfJzj7qDQdWbFDXqH6ObWTN8prAROD7EfFqTfmGwAmpK+kVSa8AG9Tt29X31Juf3QhgKDCzpr2bU3mnlyJiSYM2R5Ad79qfb97vQnf1WUmcLKwhSR8n+yD8U/269J/1CRGxMfB54HhJu3au7qLKvDOPDWrejyH7z/JFsu6LoTVxDWHZD6m8eheQfbjW1r0EeCFnv3ovppjq63q24P6N4mwU24Ka5b8BewMXS9qxpvwZ4HsRsWbNa2hEXJkbRPc/u668CLwFbFnT3hoRUeTDexHZ8R5dU7ZBF9tahTlZ2DIkrS5pb+Aq4LKImNVgm70lfSR1T7wGvJdekH0Ib9yLpg+RtIWkocDpwLXp0tq/ACtL2kvSCmR9
+rV98y8AHd0M0l4JHCdpI0mrsXSMY0kX2zeUYrkG+J6kYZI2BI4HLut+z2XiXKdzcL0mtpMljZA0nGwMoP4y4Olk3VXXS/pEKj4fOFLSJ5RZNR2fYXlB5PzsGoqIv6c2fyhp3VTPKEldjT/V7vsecB1wqqShkjYjG+up1dvfGWshJwvr9FtJi8n+az0JOBvo6gqUscAfyPqR7wZ+mj7UAL5P9gH4iqRv9KD9S4FLyLpPVgaOgezqLOCrwAVk/8W/QTZA2+lX6etLkh5oUO9Fqe47gCeB/wGO7kFctY5O7c8lO+O6ItWfKyL+TJYc5qZjMxL4LjADeASYBTyQyur3vZXsZzFN0sciYgbZGMKPyc4+5rB0ADtPdz+77pyY2rlH0mupjqJjCF8jG6x+nuxncSXZGEunU4Gp6bgcULBOazH54Udm1kqSzgQ+FBEtv8vees9nFmbWVJI2k7R16jLbnuxS2OvLjst6ZrDeTWtm1TGMrOtpJLCQ7JLjG0qNyHrM3VBmZpbL3VBmZpZrQHdDDR8+PDo6OsoOw8xsQJk5c+aLETEif8ulBnSy6OjoYMaMGWWHYWY2oEh6On+rZbkbyszMcjlZmJlZLicLMzPL5WRhZma5nCzMzCyXk4WZmeVysjAzs1xOFmZmlsvJwszMcjlZ2IDQMekmOibdVHYYZm3LycLMzHI5WZiZWS4nC6scdzmZVY+ThZmZ5XKyMDOzXAP6eRY2cNV2Mz11xl4lRmJmRfjMwszMcjlZ2KDjAXKz/udkYWZmuZwszMwsl5OFmZnlalqykHSRpIWSZteUnSXpz5IekXS9pDVr1k2WNEfSE5L+qVlxmZlZzzXzzOISYPe6sluBrSJia+AvwGQASVsABwFbpn1+KmlIE2MzM7MeaFqyiIg7gJfrym6JiCVp8R5gdHq/D3BVRLwdEU8Cc4DtmxWbmZn1TJljFkcAv0/vRwHP1Kybn8rMzKwCSrmDW9JJwBLg8s6iBptFF/tOBCYCjBkzpinxWe/4rmyzwavlZxaSxgN7A1+KiM6EMB/YoGaz0cCCRvtHxJSIGBcR40aMGNHcYM3MDGhxspC0O3Ai8IWIeLNm1TTgIEkrSdoIGAvc18rYzMysa03rhpJ0JbAzMFzSfOAUsqufVgJulQRwT0QcGRGPSroGeIyse+qoiHivWbGZmVnPNC1ZRMTBDYov7Gb77wHfa1Y8Vl2dYx0e5zCrLt/BbWZmuZwszMwsl5OFmZnl8pPyrDKKPIOifnzD4x1mreEzCzMzy+VkYWZmuZwszMwsl5OFmZnlcrIwM7NcThZmZpbLycLMzHI5Wdig1zHppkL3cJhZ13xTnjVV/Ye0b54zG5h8ZmFmZrmcLMzMLJeThZmZ5XKyMDOzXE4WZmaWy8nCzMxyOVmYmVkuJwszM8vlZGFmZrl8B7cNSJ6+w6y1mnZmIekiSQslza4pW1vSrZL+mr6ulcol6TxJcyQ9Imm7ZsVlZmY918xuqEuA3evKJgG3RcRY4La0DLAHMDa9JgI/a2JcZmbWQ03rhoqIOyR11BXvA+yc3k8FpgMnpvJfRkQA90haU9L6EfFcs+Kz9lXbheWJDc2KafUA93qdCSB9XTeVjwKeqdlufiozM7MKqMrVUGpQFg03lCZKmiFpxqJFi5oclpmZQeuvhnqhs3tJ0vrAwlQ+H9igZrvRwIJGFUTEFGAKwLhx4xomFGs/vjrKrLlafWYxDRif3o8Hbqgp/3K6KmoH4FWPV5iZVUfTziwkXUk2mD1c0nzgFOAM4BpJE4B5wP5p898BewJzgDeBw5sVl5mZ9Vwzr4Y6uItVuzbYNoCjmhWL9Vxnt067XS3Urt+3WZ6qDHCbmVmFOVlYW+uYdJMHx80KcLIwM7NcuclC0o6SVk3vD5F0tqQNmx+amZlVRZEzi58Bb0raBvgm8DTwy6ZGZWZmlVIkWSxJVyvtA5wbEecCw5oblpmZVUmRS2cXS5oMHAJ8WtIQYIXmhmXWdx64
Nus/Rc4sDgTeBiZExPNkE/yd1dSozMysUnLPLFKCOLtmeR4eszAzaytFrob65/Rku1clvSZpsaTXWhGcmZlVQ5Exi/8APh8Rjzc7GLNm8hiGWe8VGbN4wYnCzKy9FTmzmCHpauA3ZAPdAETEdU2LyirDjyA1MyiWLFYnmzb8czVlAThZmJm1iSJXQ/nZEjboeTzDrHtFrobaRNJtkman5a0lndz80MzMrCqKDHCfD0wG3gWIiEeAg5oZlJmZVUuRZDE0Iu6rK1vSjGDMzKyaiiSLFyV9mGxQG0n7Ac81NSozM6uUIldDHQVMATaT9CzwJNmkgmZm1iaKJItnI2K39ACk5SJisaS1mx2Yma9QMquOIt1Q10laPiLeSIniQ8CtzQ7MzMyqo0iy+A1wraQhkjqAW8iujjIzszZR5Ka88yWtSJY0OoB/i4i7+tKopOOAr5ANms8CDgfWB64C1gYeAA6NiHf60o6ZmfWPLs8sJB3f+QJWBjYAHgJ2SGW9ImkUcAwwLiK2AoaQ3bdxJvDDiBgL/A2Y0Ns2zMysf3XXDTWs5rUacD0wp6asL5YHVpG0PDCU7FLczwLXpvVTgX372IaZmfWTLruhIuK02mVJw7LieL0vDUbEs5J+AMwD3iIbA5kJvBIRnTf7zSd7fOsHSJoITAQYM2ZMX0IxM7OCiswNtZWkB4HZwKOSZkrasrcNSloL2AfYCBgJrArs0WDTaLR/REyJiHERMW7EiBG9DcPMzHqgyH0WU4DjI+J2AEk7k80X9cletrkb8GRELEr1XZfqWjNdorsEGA0s6GX91gfd3dvQuc7PtTBrP0UunV21M1EARMR0srOB3ppHNkg+VJKAXYHHgNuB/dI244Eb+tCGmZn1oyLJYq6kb0vqSK+Tyab86JWIuJdsIPsBsstmlyM7ezkROF7SHGAd4MLetmFmZv2rSDfUEcBpLH0y3h3AYX1pNCJOAU6pK54LbN+Xes3MrDmKJIvdIuKY2gJJ+wO/ak5IZmZWNUWSxWQ+mBgalZkNGo0G+j2wb+2sy2QhaQ9gT2CUpPNqVq2OH3406HiG13y+GszaWXdnFguAGcAXyG6a67QYOK6ZQZmZWbV0dwf3w8DDkq6IiHdbGJOZmVVM7qWzThRmZlbkPguzZXRMusljHGZtprspyi9NX49tXThmZlZF3Z1ZfEzShsARktaStHbtq1UBmplZ+bq7GurnwM3AxmRXQ6lmXaRyMzNrA12eWUTEeRGxOXBRRGwcERvVvJworG15zMbaUZFncP+7pG2AT6WiOyLikeaGZWZmVVLk4UfHAJcD66bX5ZKObnZgZmZWHUXmhvoK8ImIeANA0pnA3cCPmhmYNU9tF4qnrjCzIorcZyHgvZrl91h2sNvMzAa5ImcWFwP3Sro+Le+LH0xkZtZWigxwny1pOrAT2RnF4RHxYLMDMzOz6ihyZkFEPED2GFSzD/BlpGaDn+eGMjOzXE4WZmaWq9tkIWmIpD+0KhgzM6umbpNFRLwHvClpjRbFY2ZmFVRkgPt/gFmSbgXe6CyMiGOaFpWZmVVKkWRxU3r1G0lrAhcAW5HNYHsE8ARwNdABPAUcEBF/6892zcysd4rcZzFV0irAmIh4op/aPRe4OSL2k7QiMBT4FnBbRJwhaRIwCTixn9ozM7M+KDKR4OeBh8iebYGkbSVN622DklYHPk26Czwi3omIV4B9gKlps6lkd4qbmVkFFLl09lRge+AVgIh4CNioD21uDCwCLpb0oKQLJK0KrBcRz6U2niOb4fYDJE2UNEPSjEWLFvUhDDMzK6pIslgSEa/WlUUf2lwe2A74WUR8lGzQfFLRnSNiSkSMi4hxI0aM6EMYZmZWVJFkMVvSvwBDJI2V9CPgrj60OR+YHxH3puVryZLHC5LWB0hfF/ahDbOm8xPzrJ0USRZHA1sCbwNXAq8BX+9tgxHxPPCMpE1T0a7AY8A0YHwqGw/c0Ns2zMysfxW5GupN4KT00KOIiMX9
0O7RZE/cWxGYCxxOlriukTQBmAfs3w/tmJlZP8hNFpI+DlwEDEvLrwJHRMTM3jaaBsnHNVi1a2/rtIHB3TZmA1ORm/IuBL4aEXcCSNqJ7IFIWzczMDMzq44iyWJxZ6IAiIg/SeqPrigb4HyWkOk8Dn6euQ1mXSYLSdult/dJ+gXZ4HYABwLTmx+amZlVRXdnFv9Zt3xKzfu+3GdhZmYDTJfJIiJ2aWUgZmZWXUWuhloT+DLZbLDvb+8pys265nEMG2yKDHD/DrgHmAX8vbnhmJlZFRVJFitHxPFNj8QGFV8pZTa4FJnu41JJ/yppfUlrd76aHpmZmVVGkTOLd4CzgJNYehVUkE01bmZmbaBIsjge+EhEvNjsYMzMrJqKdEM9CrzZ7EDMzKy6ipxZvAc8JOl2smnKAV86a2bWTooki9+kl5mZtakiz7OY2opAzMysuorcwf0kDeaCighfDWWWw3dy22BRpBuq9iFFK5M9wc73WZiZtZHcq6Ei4qWa17MRcQ7w2RbEZmZmFVGkG2q7msXlyM40hjUtIrMBylOc2GBWpBuq9rkWS4CngAOaEo2ZmVVSkauh/FwLM7M2V6QbaiXgi3zweRanNy8sMzOrkiLdUDcArwIzqbmD28zM2keRZDE6Inbv74YlDQFmAM9GxN6SNgKuIrss9wHg0Ih4p7/bNTOznisykeBdkv6hCW0fCzxes3wm8MOIGAv8DZjQhDbNzKwXiiSLnYCZkp6Q9IikWZIe6UujkkYDewEXpGWR3btxbdpkKrBvX9owM7P+U6Qbao8mtHsO8E2W3q+xDvBKRCxJy/OBUY12lDQRmAgwZsyYJoRmZmb1ilw6+3R/Nihpb2BhRMyUtHNncaOmu4hnCjAFYNy4cQ23MTOz/lXkzKK/7Qh8QdKeZHNNrU52prGmpOXT2cVoYEEJsZmZWQNFxiz6VURMjojREdEBHAT8MSK+BNwO7Jc2G092ya6ZmVVAy5NFN04Ejpc0h2wM48KS4zEzs6SMbqj3RcR0YHp6PxfYvsx4zMyssVKThbWWZ0Utjx+CZANdlbqhzMysopwszMwsl5OFmZnlcrIwM7NcThZmZpbLycLMzHI5WZi1UMekm3wJsw1IThZmZpbLycLMzHI5WZiZWS4nCzMzy+VkYWZmuZwszMwsl5OFmZnlcrIwK5nvvbCBwM+zMCtBd8nBz76wKvKZhZmZ5XKyMDOzXE4WZmaWy8nCzMxyOVmYmVkuJwszM8vV8mQhaQNJt0t6XNKjko5N5WtLulXSX9PXtVodm5mZNVbGmcUS4ISI2BzYAThK0hbAJOC2iBgL3JaWzcysAlqeLCLiuYh4IL1fDDwOjAL2AaamzaYC+7Y6NjMza6zUO7gldQAfBe4F1ouI5yBLKJLW7WKficBEgDFjxrQmULMWqL+r23dyW5WUNsAtaTXg18DXI+K1ovtFxJSIGBcR40aMGNG8AM3M7H2lJAtJK5Alissj4rpU/IKk9dP69YGFZcRmVmWedNDK0vJuKEkCLgQej4iza1ZNA8YDZ6SvN7Q6NrMqcnKwKihjzGJH4FBglqSHUtm3yJLENZImAPOA/UuIzczMGmh5soiIPwHqYvWurYxlsPMAqZn1F9/BbTaAeQzDWsXJwszMcvlJeWYDkM8mrNV8ZmFmZrmcLMzMLJeThZmZ5XKyMDOzXE4WZoOAL6G1ZnOyMDOzXE4WZmaWy8nCzMxy+aY8s0GkdtzCc4JZf/KZhZmZ5XKyMDOzXE4WZmaWy8nCzMxyeYDbbJCqv0nPA97WFz6zMDOzXE4WZubpQiyXu6EqpL+eme0/emvEvxfWFz6zMDOzXE4WZvYB7payek4WZmaWq3JjFpJ2B84FhgAXRMQZJYfUr/p77p7+Gucwg67HNTznlFUqWUgaAvwE+F/AfOB+SdMi4rFWx1K1Pw4nBStDX7ui8n5vq/Z3Zl2rWjfU9sCciJgbEe8AVwH7lByTmVnbU0SUHcP7JO0H7B4RX0nL
hwKfiIiv1WwzEZiYFrcCZrc80GoaDrxYdhAV4WOxlI/FUj4WS20aEcN6skOluqEANShbJptFxBRgCoCkGRExrhWBVZ2PxVI+Fkv5WCzlY7GUpBk93adq3VDzgQ1qlkcDC0qKxczMkqoli/uBsZI2krQicBAwreSYzMzaXqW6oSJiiaSvAf9FdunsRRHxaDe7TGlNZAOCj8VSPhZL+Vgs5WOxVI+PRaUGuM3MrJqq1g1lZmYV5GRhZma5BmyykLS7pCckzZE0qex4yiJpA0m3S3pc0qOSji07pjJJGiLpQUk3lh1L2SStKelaSX9Ovx//WHZMZZF0XPr7mC3pSkkrlx1Tq0i6SNJCSbNrytaWdKukv6ava+XVMyCTRc20IHsAWwAHS9qi3KhKswQ4ISI2B3YAjmrjYwFwLPB42UFUxLnAzRGxGbANbXpcJI0CjgHGRcRWZBfPHFRuVC11CbB7Xdkk4LaIGAvclpa7NSCTBZ4W5H0R8VxEPJDeLyb7QBhVblTlkDQa2Au4oOxYyiZpdeDTwIUAEfFORLxSblSlWh5YRdLywFDa6P6tiLgDeLmueB9gano/Fdg3r56BmixGAc/ULM+nTT8ga0nqAD4K3FtuJKU5B/gm8PeyA6mAjYFFwMWpW+4CSauWHVQZIuJZ4AfAPOA54NWIuKXcqEq3XkQ8B9k/nMC6eTsM1GSROy1Iu5G0GvBr4OsR8VrZ8bSapL2BhRExs+xYKmJ5YDvgZxHxUeANCnQ1DEapP34fYCNgJLCqpEPKjWrgGajJwtOC1JC0AlmiuDwiris7npLsCHxB0lNk3ZKflXRZuSGVaj4wPyI6zzKvJUse7Wg34MmIWBQR7wLXAZ8sOaayvSBpfYD0dWHeDgM1WXhakESSyPqlH4+Is8uOpywRMTkiRkdEB9nvwx8jom3/e4yI54FnJG2ainYFWv5cmIqYB+wgaWj6e9mVNh3srzENGJ/ejwduyNuhUtN9FNWLaUEGsx2BQ4FZkh5KZd+KiN+VGJNVw9HA5ekfqrnA4SXHU4qIuFfStcADZFcPPkgbTf0h6UpgZ2C4pPnAKcAZwDWSJpAl0/1z6/F0H2ZmlmegdkOZmVkLOVmYmVkuJwszM8vlZGFmZrmcLMzMLJeThQ1Ykl5vQp3bStqzZvlUSd/oQ337pxlfb68r75D0LwX2P0zSj3vbvll/cbIwW9a2wJ65WxU3AfhqROxSV94B5CYLs6pwsrBBQdL/kXS/pEcknZbKOtJ/9eenZxncImmVtO7jadu7JZ2VnnOwInA6cKCkhyQdmKrfQtJ0SXMlHdNF+wdLmpXqOTOVfQfYCfi5pLPqdjkD+FRq5zhJK0u6ONXxoKT65IKkvVK8wyWNkPTr9D3fL2nHtM2p6fkFy8QraVVJN0l6OMV4YH39Zt2KCL/8GpAv4PX09XNkd+SK7B+gG8mm5+4gu2N327TdNcAh6f1s4JPp/RnA7PT+MODHNW2cCtwFrAQMB14CVqiLYyTZXbAjyGZF+COwb1o3new5CvWx7wzcWLN8AnBxer9Zqm/lzniA/w3cCayVtrkC2Cm9H0M23UuX8QJfBM6vaW+Nsn9+fg2s14Cc7sOszufS68G0vBowluwD98mI6JwGZSbQIWlNYFhE3JXKrwD27qb+myLibeBtSQuB9cgm6uv0cWB6RCwCkHQ5WbL6TQ++h52AHwFExJ8lPQ1sktbtAowDPhdLZxTejeyMp3P/1SUN6ybeWcAP0lnPjRFxZw9iM3OysEFBwPcj4hfLFGbP93i7pug9YBUaT3Hfnfo66v9uelpfI93VMZfs+RSbADNS2XLAP0bEW8tUkiWPD8QbEX+R9DGy8ZjvS7olIk7vh7itTXjMwgaD/wKOSM/0QNIoSV0+zCUi/gYslrRDKqp9xOZiYNgH9+rWvcBn0ljCEOBg4L9z9qlv5w7gSyn+Tci6lp5I654G/hn4paQtU9ktwNc6d5a0bXeNSRoJvBkRl5E9CKhdpyu3
XnKysAEvsqeeXQHcLWkW2bMb8j7wJwBTJN1N9l/9q6n8drLunYeKDgJH9qSxyWnfh4EHIiJvyudHgCVpwPk44KfAkBT/1cBhqSups40nyJLJryR9mPRM6TRI/xhwZE57/wDcl2YmPgn4bpHvzayTZ521tiRptYh4Pb2fBKwfEceWHJZZZXnMwtrVXpImk/0NPE121ZGZdcFnFmZmlstjFmZmlsvJwszMcjlZmJlZLicLMzPL5WRhZma5/j/NxM64cSIW9wAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Histogram of the natural log of the review lengths (in tokens).\n",
    "# Zero-length reviews are left out because log(0) is -inf, which plt.hist cannot bin.\n",
    "log_lengths = np.log(num_tokens[num_tokens > 0])\n",
    "plt.hist(log_lengths, bins=100)\n",
    "plt.xlim((0, 10))\n",
    "# The x axis is on a log scale of token counts; the y axis counts reviews, not tokens.\n",
    "plt.ylabel('number of reviews')\n",
    "plt.xlabel('log(number of tokens)')\n",
    "plt.title('Distribution of log review length')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "227"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Cut-off length = mean + 2 standard deviations of the token counts, truncated to an int.\n",
    "# The original rationale treats review length as roughly normally distributed, so this value\n",
    "# should cover about 95% of the reviews; the next cell measures the actual coverage.\n",
    "mid_tokens = int(np.mean(num_tokens) + 2 * np.std(num_tokens))\n",
    "mid_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.956"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of reviews strictly shorter than mid_tokens (about 95% in the recorded run)\n",
    "np.mean(num_tokens < mid_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The pre-trained vocabulary is large, so to save training time only its first 50000 words\n",
    "# are kept and a smaller embedding matrix is built from them for use in Keras.\n",
    "num_words = 50000\n",
    "# Row i holds the vector of the i-th word of the model: shape [num_words, embedding_dim],\n",
    "# i.e. 50000 x 300 in the recorded run, stored as float32.\n",
    "embedding_matrix = np.array(\n",
    "    [cn_model[cn_model.index2word[row]] for row in range(num_words)],\n",
    "    dtype='float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1.160710e-01  3.823570e-01  3.350100e-02  6.509880e-01 -4.422000e-03\n",
      " -3.975420e-01 -3.907440e-01 -1.683600e-02  6.122400e-02  6.713800e-02\n",
      " -6.707850e-01  2.157060e-01 -5.070170e-01  7.741850e-01  4.844490e-01\n",
      " -4.285320e-01  1.542340e-01 -6.967980e-01  9.287580e-01  4.554930e-01\n",
      " -1.046007e+00  6.835810e-01 -1.509860e-01 -3.230890e-01 -4.166720e-01\n",
      "  8.421480e-01 -1.314240e-01 -2.506910e-01  6.844910e-01 -1.548800e-01\n",
      " -6.630620e-01  6.492340e-01  1.783220e-01 -2.903320e-01  7.403010e-01\n",
      "  6.675600e-02 -1.797410e-01  1.260600e-02 -7.401170e-01 -7.841450e-01\n",
      "  2.338000e-02  2.171590e-01  4.285500e-01 -4.258850e-01 -1.838450e-01\n",
      " -9.258050e-01 -9.325810e-01  3.774180e-01  9.589400e-02 -1.021126e+00\n",
      " -5.428410e-01  3.356190e-01  4.242400e-02  6.216920e-01 -1.315300e-01\n",
      " -8.639000e-03 -6.276220e-01  4.994250e-01  7.900700e-02 -4.089200e-02\n",
      "  6.278660e-01  1.197020e-01  2.085100e-02 -4.527590e-01  4.065140e-01\n",
      " -1.684710e-01  3.961030e-01  4.670930e-01 -6.273120e-01 -7.722130e-01\n",
      " -3.347990e-01  4.145790e-01 -5.500580e-01 -6.130720e-01  3.308020e-01\n",
      " -2.506440e-01 -1.058240e-01  3.655210e-01 -3.633430e-01 -3.346100e-01\n",
      "  2.015690e-01 -4.844690e-01  1.476390e-01 -2.433930e-01  5.673500e-02\n",
      " -2.821140e-01  7.345520e-01  7.194030e-01 -1.230510e-01 -6.388350e-01\n",
      " -6.253240e-01  5.781780e-01  7.626670e-01  8.289600e-02  1.370876e+00\n",
      " -2.285000e-02 -1.039510e-01  6.824250e-01 -1.930600e-02 -5.609150e-01\n",
      "  6.419920e-01 -1.494030e-01  8.059200e-02 -2.808490e-01  4.986720e-01\n",
      "  7.685460e-01 -3.646280e-01 -4.845320e-01 -1.107231e+00  6.448720e-01\n",
      "  1.874000e-03  4.998070e-01 -1.379960e-01  7.739700e-02  1.126077e+00\n",
      "  3.452320e-01  3.777990e-01 -2.545810e-01 -6.642710e-01  4.144310e-01\n",
      "  2.040630e-01  8.529700e-02  3.668600e-01  1.014550e-01 -1.303170e-01\n",
      "  3.807690e-01  4.263100e-02 -8.337220e-01 -9.495300e-02  1.061490e-01\n",
      "  2.937750e-01 -5.399930e-01  2.093250e-01  8.658110e-01 -2.148700e-01\n",
      " -5.258930e-01  4.637400e-02 -7.496300e-02 -1.093580e-01  5.242600e-02\n",
      "  7.138680e-01  8.699890e-01  1.096506e+00 -1.567000e-03 -1.182124e+00\n",
      " -6.967010e-01  4.209730e-01  8.549780e-01 -1.440820e-01  9.787120e-01\n",
      "  7.551090e-01  3.746130e-01  3.022730e-01 -1.112939e+00  8.567640e-01\n",
      " -2.772870e-01 -8.780370e-01 -7.638500e-02  1.695940e-01 -3.052300e-01\n",
      " -4.175490e-01 -4.382760e-01 -4.814640e-01  1.192990e-01  2.295200e-02\n",
      " -1.129100e-01 -9.305060e-01  5.229810e-01 -6.922720e-01 -2.130700e-01\n",
      " -5.198320e-01 -2.084170e-01 -4.429420e-01 -1.077630e-01 -1.689554e+00\n",
      " -7.761370e-01 -5.924380e-01  3.138350e-01  4.220150e-01  2.024260e-01\n",
      " -1.616980e-01  2.530190e-01  4.464500e-02 -6.125620e-01 -3.403300e-02\n",
      " -7.468480e-01 -1.536210e-01 -2.413320e-01 -4.949300e-02 -8.470000e-01\n",
      " -6.870700e-01  6.863710e-01 -7.252470e-01 -6.721740e-01 -3.911280e-01\n",
      " -1.466190e-01  3.246700e-01 -7.967330e-01 -1.397470e+00  5.409480e-01\n",
      " -2.543500e-02 -3.207220e-01 -1.195780e-01  4.670010e-01  8.050000e-03\n",
      " -1.423320e-01 -6.229020e-01 -1.535790e-01  6.506570e-01 -6.872760e-01\n",
      "  4.966810e-01 -2.963600e-01  2.267390e-01 -3.492640e-01 -6.970900e-02\n",
      " -2.070350e-01 -1.306400e-01 -3.248480e-01 -7.160300e-01  2.174630e-01\n",
      " -1.804640e-01  2.316510e-01 -1.086708e+00  3.346260e-01 -1.955490e-01\n",
      " -5.870700e-02  8.666000e-02 -5.290260e-01  7.620600e-02 -6.251110e-01\n",
      " -6.947800e-01 -1.442259e+00  9.090940e-01  1.144080e-01  8.196530e-01\n",
      "  2.428270e-01  8.581950e-01  9.534730e-01 -2.295000e-01 -2.297220e-01\n",
      "  6.468970e-01  7.454800e-02  9.851700e-01  7.378320e-01 -4.618830e-01\n",
      "  2.567200e-02  9.727200e-02 -1.553360e-01  1.813450e-01 -2.058300e-01\n",
      " -1.381550e-01  3.380190e-01 -5.999100e-02  5.723730e-01 -5.025500e-02\n",
      " -4.248700e-02 -4.839700e-01 -5.199490e-01  5.618170e-01  8.533330e-01\n",
      "  1.858700e-01  6.647300e-02 -1.163222e+00 -3.600480e-01  1.798240e-01\n",
      " -3.766590e-01 -6.445710e-01  3.945300e-01  1.015964e+00 -3.036340e-01\n",
      "  6.894950e-01 -9.391810e-01  3.086100e-02  2.383370e-01 -7.463050e-01\n",
      " -1.491230e-01  1.700980e-01  1.527770e-01 -1.734950e-01 -3.817440e-01\n",
      "  3.201760e-01  1.504552e+00  5.963760e-01 -2.580100e-01 -4.934740e-01\n",
      "  5.659830e-01 -1.079104e+00  2.311180e-01 -2.090500e-02 -2.059950e-01\n",
      " -1.036721e+00 -1.531850e-01  5.824100e-02  7.875850e-01  4.788590e-01\n",
      " -9.657600e-02 -5.051300e-01 -8.506700e-01  3.563480e-01 -4.049710e-01]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "300"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check that row indices line up: compare the model's vector for word 222 with\n",
    "# row 222 of the new matrix. A result of 300 means every component matches.\n",
    "word_222 = cn_model.index2word[222]\n",
    "print(cn_model[word_222])\n",
    "np.sum(cn_model[word_222] == embedding_matrix[222])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000, 300)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the new embedding matrix; it is passed to the Keras Embedding layer later in the notebook\n",
    "embedding_matrix.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad short reviews and truncate long ones to exactly mid_tokens indices. The input is a list\n",
    "# of lists and the result is a numpy array. Both padding and truncation act on the front ('pre').\n",
    "train_pad = pad_sequences(\n",
    "    train_tokens,\n",
    "    maxlen=mid_tokens,\n",
    "    padding='pre',\n",
    "    truncating='pre',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Words whose index falls outside the 50000 kept embedding rows are replaced by 0\n",
    "out_of_range = train_pad >= num_words\n",
    "train_pad[out_of_range] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([    0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "           0,     0,     0,     0,     0,  3053,    57,   169,    73,\n",
       "           1,    25, 11216,    49,   163, 15985,     0,     0,    30,\n",
       "           8,     0,     1,   228,   223,    40,    35,   653,     0,\n",
       "           5,  1642,    29, 11216,  2751,   500,    98,    30,  3159,\n",
       "        2225,   371,  6285,   169, 27396,     1,  1191,  5432,  1080,\n",
       "       20055,    57,   562,     1, 22671,    40,    35,   169,  2567,\n",
       "           0, 42665,  7761,     0,     0, 41281,     0,     0, 35891,\n",
       "           0, 28781,    57,   169,  1419,     1, 11670,     0, 19470,\n",
       "           1,     0,     0,   169, 35071,    40,   562,    35, 12398,\n",
       "         657,  4857])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# After 'pre' padding the leading positions are filled with 0 and the review's tokens sit at the end\n",
    "train_pad[33]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4000,)\n"
     ]
    }
   ],
   "source": [
    "# Target vector: 1 for every positive review, 0 for every negative one.\n",
    "# Sized from the file lists instead of a hard-coded 2000 so the labels stay aligned with\n",
    "# train_texts_orig (positives first, then negatives) if the dataset size changes.\n",
    "train_target = np.concatenate((np.ones(len(pos_txts)), np.zeros(len(neg_txts))))\n",
    "print(train_target.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 10% of the samples for testing and train on the remaining 90%.\n",
    "# The fixed random_state makes the shuffle repeatable for a given sample order.\n",
    "split_parts = train_test_split(\n",
    "    train_pad, train_target, test_size=0.1, random_state=12)\n",
    "X_train, X_test, y_train, y_test = split_parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                                                                                                                                                               房间很大还有海景阳台走出酒店就是沙滩非常不错唯一遗憾的就是不能刷 不方便\n",
      "class:  1.0\n"
     ]
    }
   ],
   "source": [
    "# Inspect one training sample to confirm the pipeline is sane: turn its indices back into text.\n",
    "def reverse_tokens(tokens):\n",
    "    \"\"\"Rebuild readable text from vocabulary indices; index 0 is rendered as a space.\"\"\"\n",
    "    return ''.join(\n",
    "        cn_model.index2word[idx] if idx != 0 else ' ' for idx in tokens)\n",
    "\n",
    "print(reverse_tokens(X_train[35]))\n",
    "print('class: ',y_train[35])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a Keras Sequential model; layers are added in the following cells\n",
    "model = Sequential()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\46124\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    }
   ],
   "source": [
    "# First layer: an embedding initialised from the pre-trained matrix and kept frozen\n",
    "# (trainable=False), taking sequences of mid_tokens indices.\n",
    "embedding_layer = Embedding(\n",
    "    num_words,\n",
    "    embedding_dim,\n",
    "    weights=[embedding_matrix],\n",
    "    input_length=mid_tokens,\n",
    "    trainable=False,\n",
    ")\n",
    "model.add(embedding_layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recurrent part: two stacked LSTM layers, the first of them bidirectional.\n",
    "# The first layer returns the full sequence so the second can consume it step by step;\n",
    "# the second returns only its final output.\n",
    "bidirectional_lstm = Bidirectional(LSTM(units=32, return_sequences=True))\n",
    "model.add(bidirectional_lstm)\n",
    "final_lstm = LSTM(units=16, return_sequences=False)\n",
    "model.add(final_lstm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fully connected output layer: a single unit with a sigmoid activation.\n",
    "output_layer = Dense(units=1, activation='sigmoid')\n",
    "model.add(output_layer)\n",
    "# Gradient descent uses the Adam optimizer with a learning rate of 0.001 (1e-3).\n",
    "optimizer = Adam(lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the model: binary cross-entropy loss, the optimizer defined earlier,\n",
    "# and accuracy as the reported metric.\n",
    "compile_options = {\n",
    "    'loss': 'binary_crossentropy',\n",
    "    'optimizer': optimizer,\n",
    "    'metrics': ['accuracy'],\n",
    "}\n",
    "model.compile(**compile_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (None, 227, 300)          15000000  \n",
      "_________________________________________________________________\n",
      "bidirectional_1 (Bidirection (None, 227, 64)           85248     \n",
      "_________________________________________________________________\n",
      "lstm_2 (LSTM)                (None, 16)                5184      \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 17        \n",
      "=================================================================\n",
      "Total params: 15,090,449\n",
      "Trainable params: 90,449\n",
      "Non-trainable params: 15,000,000\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Inspect the model structure (about 90k trainable parameters in the recorded run)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint callback: writes the model weights (weights only) to path_checkpoint whenever\n",
    "# the validation loss improves, so the best model seen during training is kept.\n",
    "path_checkpoint = 'sentiment_checkpoint.keras'\n",
    "checkpoint = ModelCheckpoint(\n",
    "    filepath=path_checkpoint,\n",
    "    monitor='val_loss',\n",
    "    verbose=1,\n",
    "    save_weights_only=True,\n",
    "    save_best_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unable to open file (unable to open file: name = 'sentiment_checkpoint.keras', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)\n"
     ]
    }
   ],
   "source": [
    "# 尝试加载已训练模型\n",
    "try:\n",
    "    model.load_weights(path_checkpoint)\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义early stoping如果3个epoch内validation loss没有改善则停止训练\n",
    "earlystopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 自动降低learning rate\n",
    "lr_reduction = ReduceLROnPlateau(monitor='val_loss',\n",
    "                                       factor=0.1, min_lr=1e-5, patience=0,\n",
    "                                       verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义callback函数\n",
    "callbacks = [\n",
    "    earlystopping,\n",
    "#     checkpoint,\n",
    "    lr_reduction\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\46124\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "Train on 3240 samples, validate on 360 samples\n",
      "Epoch 1/20\n",
      "3240/3240 [==============================] - 21s 6ms/step - loss: 0.6494 - acc: 0.6225 - val_loss: 0.5470 - val_acc: 0.7444\n",
      "Epoch 2/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.4689 - acc: 0.7907 - val_loss: 0.3983 - val_acc: 0.8556\n",
      "Epoch 3/20\n",
      "3240/3240 [==============================] - 18s 5ms/step - loss: 0.4016 - acc: 0.8235 - val_loss: 0.3841 - val_acc: 0.8472\n",
      "Epoch 4/20\n",
      "3240/3240 [==============================] - 18s 5ms/step - loss: 0.3588 - acc: 0.8534 - val_loss: 0.3667 - val_acc: 0.8639\n",
      "Epoch 5/20\n",
      "3240/3240 [==============================] - 18s 5ms/step - loss: 0.3394 - acc: 0.8651 - val_loss: 0.3837 - val_acc: 0.8500\n",
      "\n",
      "Epoch 00005: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.\n",
      "Epoch 6/20\n",
      "3240/3240 [==============================] - 18s 5ms/step - loss: 0.3159 - acc: 0.8762 - val_loss: 0.3229 - val_acc: 0.8722\n",
      "Epoch 7/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.2970 - acc: 0.8873 - val_loss: 0.3135 - val_acc: 0.8778\n",
      "Epoch 8/20\n",
      "3240/3240 [==============================] - 18s 5ms/step - loss: 0.2886 - acc: 0.8932 - val_loss: 0.3077 - val_acc: 0.8861\n",
      "Epoch 9/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.2848 - acc: 0.8948 - val_loss: 0.3032 - val_acc: 0.8944\n",
      "Epoch 10/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.2776 - acc: 0.8969 - val_loss: 0.3055 - val_acc: 0.8917\n",
      "\n",
      "Epoch 00010: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.\n",
      "Epoch 11/20\n",
      "3240/3240 [==============================] - 18s 5ms/step - loss: 0.2775 - acc: 0.8938 - val_loss: 0.3017 - val_acc: 0.8944\n",
      "Epoch 12/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.2730 - acc: 0.8988 - val_loss: 0.2989 - val_acc: 0.8972\n",
      "Epoch 13/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.2709 - acc: 0.9003 - val_loss: 0.2983 - val_acc: 0.8972\n",
      "Epoch 14/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.2706 - acc: 0.9003 - val_loss: 0.2976 - val_acc: 0.8972\n",
      "Epoch 15/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.2698 - acc: 0.9006 - val_loss: 0.2971 - val_acc: 0.8972\n",
      "Epoch 16/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.2691 - acc: 0.9015 - val_loss: 0.2969 - val_acc: 0.8972\n",
      "Epoch 17/20\n",
      "3240/3240 [==============================] - 18s 6ms/step - loss: 0.2689 - acc: 0.9019 - val_loss: 0.2966 - val_acc: 0.8972\n",
      "Epoch 18/20\n",
      "3240/3240 [==============================] - 19s 6ms/step - loss: 0.2685 - acc: 0.9006 - val_loss: 0.2965 - val_acc: 0.8972\n",
      "Epoch 19/20\n",
      "3240/3240 [==============================] - 19s 6ms/step - loss: 0.2676 - acc: 0.9025 - val_loss: 0.2958 - val_acc: 0.8972\n",
      "Epoch 20/20\n",
      "3240/3240 [==============================] - 19s 6ms/step - loss: 0.2669 - acc: 0.9022 - val_loss: 0.2956 - val_acc: 0.8972\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x19165d4fe80>"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 开始训练\n",
    "model.fit(X_train, y_train,\n",
    "          validation_split=0.1, \n",
    "          epochs=20,\n",
    "          batch_size=128,\n",
    "          callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "400/400 [==============================] - 1s 3ms/step\n",
      "Accuracy:87.75%\n"
     ]
    }
   ],
   "source": [
    "result = model.evaluate(X_test, y_test)\n",
    "print('Accuracy:{0:.2%}'.format(result[1]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
